import os
import logging
import argparse
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
# Add the project root to the import path so the bert_finetune_cls package resolves
import sys
sys.path.append("./")

from bert_finetune_cls.utils import init_logger, load_tokenizer, get_intent_labels, MODEL_CLASSES
# Module-level logger
logger = logging.getLogger(__name__)


def get_device(pred_config):
    """
    获得device参数
    :param pred_config:
    :return:
    """
    return "cuda" if torch.cuda.is_available() and not pred_config.no_cuda else "cpu"


def get_args(pred_config):
    """
    得到训练好后保存的模型参数
    :param pred_config:
    :return:
    """
    return torch.load(os.path.join(pred_config.model_dir, 'training_args.bin'))


def load_model(pred_config, args, device):
    """
    加载模型
    :param pred_config:
    :param args: 参数
    :param device: 配置
    :return:
    """
    # Check whether the trained model directory exists
    if not os.path.exists(pred_config.model_dir):
        raise Exception("Model doesn't exist! Train first!")

    try:
        # Load the fine-tuned classification model from the model directory
        model = MODEL_CLASSES[args.model_type][1].from_pretrained(
            pred_config.model_dir,
            args=args,
            intent_label_lst=get_intent_labels(args),
        )
        model.to(device)
        # Put the model in evaluation mode; it is no longer being trained
        model.eval()
        logger.info("***** Model Loaded *****")
    except Exception:
        raise Exception("Some model files might be missing...")

    return model


def read_input_file(pred_config):
    """
    逐行读取输入文件
    :param pred_config:
    :return:
    """
    lines = []
    with open(pred_config.input_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            words = line.split()
            lines.append(words)

    return lines
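
# Illustrative input format for read_input_file (an assumption based on the
# whitespace split above, not copied from the repo's sample data): one
# tokenized utterance per line, e.g.
#   i want to fly from boston to denver
#   show me flights from pittsburgh to baltimore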


def convert_input_file_to_tensor_dataset(lines,
                                         pred_config,
                                         args,
                                         tokenizer,
                                         cls_token_segment_id=0,
                                         pad_token_segment_id=0,
                                         sequence_a_segment_id=0,
                                         mask_padding_with_zero=True):
    """
    将原始输入数据转换成BERT模型需要的数据
    :param lines: 输入文件
    :param pred_config: 训练好的模型参数
    :param args: 参数
    :param tokenizer: 分词模型
    :param cls_token_segment_id: -100
    :param pad_token_segment_id: 0
    :param sequence_a_segment_id: 0
    :param mask_padding_with_zero: 0
    :return:
    """
    # Special tokens are taken from the tokenizer of the current model
    # [CLS]
    cls_token = tokenizer.cls_token
    # [SEP]
    sep_token = tokenizer.sep_token
    # [UNK]
    unk_token = tokenizer.unk_token
    # [PAD]
    pad_token_id = tokenizer.pad_token_id
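
    # Note: for standard BERT-style vocabularies these typically resolve to
    # "[CLS]", "[SEP]", "[UNK]" and a pad id of 0, but the actual values always
    # come from the tokenizer that was loaded for the trained model.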

    all_input_ids = []
    all_attention_mask = []
    all_token_type_ids = []
    # Iterate over each input sentence
    for words in lines:
        tokens = []
        # Iterate over each word in the sentence
        for word in words:
            # Tokenize the word into sub-word tokens
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # Fall back to [UNK] for words the tokenizer cannot encode
                word_tokens = [unk_token]
            tokens.extend(word_tokens)

        # Account for [CLS] and [SEP]
        special_tokens_count = 2
        # Truncate if the sentence is too long, leaving room for [CLS] and [SEP]
        if len(tokens) > args.max_seq_len - special_tokens_count:
            tokens = tokens[: (args.max_seq_len - special_tokens_count)]

        # Add [SEP] token
        tokens += [sep_token]
        token_type_ids = [sequence_a_segment_id] * len(tokens)

        # Add [CLS] token
        tokens = [cls_token] + tokens
        token_type_ids = [cls_token_segment_id] + token_type_ids

        # Map tokens to their ids in the BERT vocabulary
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Pad every field up to the maximum sequence length
        padding_length = args.max_seq_len - len(input_ids)
        # Token ids, padded with the [PAD] id
        input_ids = input_ids + ([pad_token_id] * padding_length)
        # Attention mask: 1 for real tokens, 0 for padding
        attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        # token_type_ids mark whether a token belongs to sentence A or B; only sentence A is used here
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
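
        # Illustrative result for args.max_seq_len = 8 and the input "list flights"
        # (hypothetical ids; the real values depend on the tokenizer's vocabulary):
        #   tokens         = [CLS] list flights [SEP]
        #   input_ids      = [101, 2443, 7599, 102, 0, 0, 0, 0]
        #   attention_mask = [1, 1, 1, 1, 0, 0, 0, 0]
        #   token_type_ids = [0, 0, 0, 0, 0, 0, 0, 0]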

        all_input_ids.append(input_ids)
        all_attention_mask.append(attention_mask)
        all_token_type_ids.append(token_type_ids)

    # Convert the lists to long tensors
    all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
    all_attention_mask = torch.tensor(all_attention_mask, dtype=torch.long)
    all_token_type_ids = torch.tensor(all_token_type_ids, dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids)

    return dataset


def predict(pred_config):
    # Load the training arguments saved with the model
    args = get_args(pred_config)
    device = get_device(pred_config)
    # Load the trained model
    model = load_model(pred_config, args, device)
    logger.info(args)
    # List of intent labels
    intent_label_lst = get_intent_labels(args)

    # Label index that is ignored when computing the loss
    pad_token_label_id = args.ignore_index
    # Load the tokenizer matching the trained model
    tokenizer = load_tokenizer(args)
    # Read the input file
    lines = read_input_file(pred_config)
    # Convert the input lines into a TensorDataset
    dataset = convert_input_file_to_tensor_dataset(
        lines,
        pred_config,
        args,
        tokenizer,
    )

    # SequentialSampler yields examples in their original order
    sampler = SequentialSampler(dataset)
    # Batch the examples for prediction
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)

    intent_preds = None
    # Predict batch by batch
    for batch in tqdm(data_loader, desc="Predicting"):
        batch = tuple(t.to(device) for t in batch)
        # torch.no_grad(): no gradients are computed inside this block
        with torch.no_grad():
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "intent_label_ids": None,}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = batch[2]
            # Forward pass
            outputs = model(**inputs)
            # Intent classification logits
            intent_logits = outputs[0]

            if intent_preds is None:
                # First batch: detach from the graph and move to CPU so numpy can read the tensor
                intent_preds = intent_logits.detach().cpu().numpy()
            else:
                # Later batches: append along the batch dimension
                intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
    # Take the argmax over the logits to get the predicted intent index
    intent_preds = np.argmax(intent_preds, axis=1)
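    # intent_preds now has shape (num_examples,): one predicted label index per input line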

    # Write the predicted intent labels to the output file, one per input line
    with open(pred_config.output_file, "w", encoding="utf-8") as f:
        for words, intent_pred in zip(lines, intent_preds):
            f.write("{}\n".format(intent_label_lst[intent_pred]))

    logger.info("Prediction Done!")


if __name__ == "__main__":
    # Initialize logging
    init_logger()
    # Build the argument parser
    parser = argparse.ArgumentParser()

    parser.add_argument("--input_file", default="sample_pred_in.txt", type=str, help="Input file for prediction")
    parser.add_argument("--output_file", default="sample_pred_out.txt", type=str, help="Output file for prediction")
    parser.add_argument("--model_dir", default="./atis_model", type=str, help="Path to save, load model")

    parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    # parse_args() collects every option defined with add_argument into a namespace,
    # so each option is available as an attribute of pred_config
    pred_config = parser.parse_args()
    # Run prediction
    predict(pred_config)
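
# Example invocation (a sketch; assumes this script lives under bert_finetune_cls/
# and is run from the repository root, with a trained model saved in ./atis_model):
#   python bert_finetune_cls/predict.py \
#       --input_file sample_pred_in.txt \
#       --output_file sample_pred_out.txt \
#       --model_dir ./atis_model \
#       --batch_size 32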
